# Copyright 2021-2024 The PySCF Developers. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
import pyscf
from pyscf.dft import rks
from pyscf.hessian.rks import _get_vnlc_deriv1, _get_vnlc_deriv1_numerical, \
                              _get_enlc_deriv2, _get_enlc_deriv2_numerical

# Expected errors against finite differences (dx = 1e-3) for each nlcgrids setting:
# nlcgrids    error
# (10,14)     0.0023105784817920596
# (50,194)    0.00011128408988071454
# (75,302)    2.794393552631863e-06
# (99,590)    2.3908377588983953e-07
# (110,770)   1.3147625436667545e-07

def setUpModule():
    global mol

    atom = '''
    O  0.0000  0.7375 -0.0528
    O  0.0000 -0.7375 -0.1528
    H  0.8190  0.8170  0.4220
    H -0.8190 -0.8170  1.4220
    '''
    basis = 'def2-svp'

    mol = pyscf.M(atom=atom, basis=basis, max_memory=32000,
                  output='/dev/null', verbose=1)

def tearDownModule():
    global mol
    mol.stdout.close()
    del mol

def make_mf(mol, nlcgrid = (75,302), vv10_only = False, density_fitting = False):
    '''Build and converge an RKS object that includes a VV10 non-local term.

    Args:
        mol: molecule to run the calculation on.
        nlcgrid: ``atom_grid`` assigned to ``mf.nlcgrids``. The default
            (75, 302) nlc grid is required to reduce the error in de2 below 1e-5.
        vv10_only: if True, the semilocal functional is zeroed out
            (``"0*PBE,0*PBE"``) and only ``nlc = "vv10"`` remains, evaluated
            on a minimal (3, 6) XC grid; otherwise wB97X-V with grid level 5.
        density_fitting: if True, wrap the object with ``density_fit`` using
            the ``def2-universal-jkfit`` auxiliary basis.

    Returns:
        The converged mean-field object (convergence is asserted).
    '''
    if vv10_only:
        mf = rks.RKS(mol, xc = "0*PBE,0*PBE")
        mf.nlc = "vv10"
        # The semilocal part is multiplied by zero, so a tiny XC grid is used.
        mf.grids.atom_grid = (3,6)
    else:
        mf = rks.RKS(mol, xc = "wb97x-v")
        mf.grids.level = 5

    # Tight thresholds for SCF and CPSCF.
    mf.conv_tol = 1e-13
    mf.direct_scf_tol = 1e-16
    mf.conv_tol_cpscf = 1e-10
    mf.nlcgrids.atom_grid = nlcgrid

    if density_fitting:
        mf = mf.density_fit(auxbasis = "def2-universal-jkfit")

    mf.kernel()
    assert mf.converged
    return mf

def _displaced_gradient(mf, mol_displaced, coords):
    '''Re-run ``mf`` at ``coords`` (Bohr) and return its nuclear gradient.

    ``mol_displaced`` is moved in place to ``coords`` and rebuilt. ``mf`` is
    reset onto it and converged (asserted). The gradient is evaluated with
    ``grid_response = True``.
    '''
    mol_displaced.set_geom_(coords, unit='Bohr')
    mol_displaced.build()
    mf.reset(mol_displaced)
    mf.kernel()
    assert mf.converged
    grad_obj = mf.Gradients()
    grad_obj.grid_response = True
    return grad_obj.kernel()

def numerical_d2enlc(mf, dx=1e-3):
    '''Nuclear Hessian of ``mf`` from central differences of analytical gradients.

    Each Cartesian coordinate of each atom is displaced by ``+dx`` and ``-dx``
    (Bohr). ``mf`` is re-converged at both geometries, and

        H[i_atom, :, i_xyz, :] = (g(+dx) - g(-dx)) / (2 * dx)

    The Hessian is printed with full precision (16 digits, no line wrapping)
    and returned.

    Args:
        mf: converged mean-field object. It is re-run at the displaced
            geometries, then reset to and re-converged at its original molecule.
        dx: displacement step in Bohr (default 1e-3).

    Returns:
        ndarray of shape (natm, natm, 3, 3).
    '''
    mol = mf.mol
    natm = mol.natm
    numerical_hessian = np.zeros([natm, natm, 3, 3])

    # Displacements are applied to a copy; ``mol`` is only read via atom_coords().
    mol_copy = mol.copy()
    try:
        for i_atom in range(natm):
            for i_xyz in range(3):
                xyz_p = mol.atom_coords()
                xyz_p[i_atom, i_xyz] += dx
                gradient_p = _displaced_gradient(mf, mol_copy, xyz_p)

                xyz_m = mol.atom_coords()
                xyz_m[i_atom, i_xyz] -= dx
                gradient_m = _displaced_gradient(mf, mol_copy, xyz_m)

                numerical_hessian[i_atom, :, i_xyz, :] = (gradient_p - gradient_m) / (2 * dx)
    finally:
        # Re-attach ``mf`` to the original molecule even if a displaced SCF
        # fails, so it is never left at a displaced geometry.
        mf.reset(mol)
    mf.kernel()

    # Scoped print options: np.set_printoptions would change numpy's global
    # print state for the rest of the process.
    big = np.iinfo(np.int32).max
    with np.printoptions(linewidth=big, threshold=big, precision=16, suppress=True):
        print(repr(numerical_hessian))
    return numerical_hessian

def analytical_d2enlc(mf):
    '''Return the analytical nuclear Hessian of ``mf``.

    A Hessian object is created from ``mf``. Its ``auxbasis_response`` is set
    to 2 before ``kernel()`` is run.
    '''
    hessian_solver = mf.Hessian()
    hessian_solver.auxbasis_response = 2
    return hessian_solver.kernel()

class KnownValues(unittest.TestCase):
    '''RKS Hessians and VV10 derivative helpers checked against finite-difference references.

    The ``*_high_cost`` tests use the default (75, 302) nlc grid. The other
    tests use coarse nlc grids ((6, 50) or (10, 14)) with looser tolerances.
    Hessian references were produced by ``numerical_d2enlc(mf)`` (see the
    commented-out call in each test). Checks use ``self.assertLess`` rather
    than bare ``assert``, which ``python -O`` strips.
    '''
    def test_vv10_only_hessian_direct_high_cost(self):
        mf = make_mf(mol, vv10_only = True)

        # reference_hessian = numerical_d2enlc(mf)
        reference_hessian = np.array([[[[ 0.5416385094555443,  0.0608587976822506,  0.4059780361467813],
         [ 0.0608583153386411,  0.2400708605971857,  0.0171074122129466],
         [ 0.4059767934618819,  0.0171075175856572,  0.0715012127059378]],

        [[ 0.0138411261745644, -0.0307587099088735, -0.0114409966949225],
         [-0.0046670364924131, -0.3662177810292988, -0.0328075691992114],
         [-0.0077329455397644, -0.0418507928570122,  0.0212319435189956]],

        [[-0.5632752360931192, -0.0127530395704345, -0.4031497682549512],
         [-0.0421637920360318,  0.1460901948911464, -0.0167164189056601],
         [-0.4027394645520488, -0.0044779744087786, -0.0910390554401674]],

        [[ 0.0077956004770896, -0.0173470482155158,  0.008612728790991 ],
         [-0.0140274868233869, -0.0199432745474626,  0.0324165758873729],
         [ 0.004495616644451 ,  0.02922124965829  , -0.0016941007819904]]],


       [[[ 0.0138412926078413, -0.0046674378371137, -0.007733128939813 ],
         [-0.0307585783373421, -0.3662178497929602, -0.0418510083978196],
         [-0.0114406839240022, -0.0328073191777634,  0.0212322387688757]],

        [[-0.0265688329405926,  0.0280390994733537,  0.0028061508691168],
         [ 0.0280390210395422,  0.4299653153587712,  0.0711125396311019],
         [ 0.0028060484083409,  0.0711122337611059, -0.0189041069124096]],

        [[ 0.0215131863576801, -0.0311007717287426,  0.0034433470949002],
         [-0.0042374598432371, -0.0574409077624405, -0.0113070426581707],
         [ 0.0060660526035594, -0.0238224669175113,  0.0092249375318598]],

        [[-0.0087856460534996,  0.0077291100670229,  0.0014836309728539],
         [ 0.0069570171381539, -0.0063065579256061, -0.017954488580163 ],
         [ 0.0025685829361799, -0.0144824477082139, -0.0115530693778898]]],


       [[[-0.5632753518693967, -0.0421644346566552, -0.402739948897668 ],
         [-0.0127529019099404,  0.1460905464591988, -0.0044775873883074],
         [-0.4031484872100144, -0.0167171608183025, -0.091038376287822 ]],

        [[ 0.0215132045501172, -0.0042371650881279,  0.0060659394314211],
         [-0.0311010481525466, -0.0574412135865288, -0.0238223321953335],
         [ 0.0034430790948267, -0.0113066360620806,  0.0092245405951541]],

        [[ 0.5449169029909662,  0.0467555396718167,  0.3963758752382196],
         [ 0.046755349856431 , -0.0877756999153601,  0.0269887824478898],
         [ 0.396375146332506 ,  0.0269891017112278,  0.0821096295715584]],

        [[-0.0031547556844091, -0.000353939952541 ,  0.0002981342411279],
         [-0.0029013998091298, -0.0008736328979408,  0.0013111371459651],
         [ 0.0033302617750142,  0.0010346951694329, -0.0002957938823878]]],


       [[[ 0.0077954733069818, -0.0140276105691228,  0.0044955368216359],
         [-0.0173472062745539, -0.0199430947489532,  0.0292210259493775],
         [ 0.0086125911367141,  0.0324162169498265, -0.0016944729875901]],

        [[-0.0087853560465749,  0.0069573606920059,  0.0025681631434793],
         [ 0.0077291657290882, -0.0063068020379475, -0.0144823954941753],
         [ 0.0014838511660648, -0.017954068078474 , -0.011553030224154 ]],

        [[-0.0031549510246531, -0.0029016744511612,  0.0033300169713923],
         [-0.0003539166857358, -0.0008736721611724,  0.0010348932112381],
         [ 0.0002979025925942,  0.0013112596001785, -0.0002961460349171]],

        [[ 0.0041448337520511,  0.009971924351676 , -0.0103937169317336],
         [ 0.0099719572323465,  0.027123568895393 , -0.0157735236657186],
         [-0.0103943449092925, -0.015773408461317 ,  0.013543649255765 ]]]])

        test_hessian = analytical_d2enlc(mf)

        self.assertLess(np.linalg.norm(test_hessian - reference_hessian), 1e-5)

    def test_vv10_only_hessian_density_fitting_high_cost(self):
        mf = make_mf(mol, vv10_only = True, density_fitting = True)

        # reference_hessian = numerical_d2enlc(mf)
        reference_hessian = np.array([[[[ 0.5415690822132557,  0.0608562722286266,  0.4059126705860394],
         [ 0.0608570487260485,  0.2400788616032656,  0.0171129679309989],
         [ 0.4059109970324659,  0.0171147380978454,  0.0714620692536805]],

        [[ 0.0138241297472988, -0.0307645192521022, -0.011444424837026 ],
         [-0.0046682567246444, -0.3662416225749254, -0.0328128942426176],
         [-0.0077385738328807, -0.0418744339949484,  0.0212219265202096]],

        [[-0.5631978607103516, -0.0127426783599893, -0.4030859833951128],
         [-0.0421540206807514,  0.146110122013654 , -0.0167203915124592],
         [-0.4026860393344656, -0.0044790556974483, -0.090979895049581 ]],

        [[ 0.0078046487382299, -0.0173490746282756,  0.008617737641714 ],
         [-0.0140347713284972, -0.0199473610911771,  0.0324203178236893],
         [ 0.0045136161368475,  0.0292387516011017, -0.0017041007286944]]],


       [[[ 0.0138236935279812, -0.0046692417816629, -0.0077378711652587],
         [-0.0307656303105697, -0.3662417637855242, -0.0418725459100933],
         [-0.0114433360939303, -0.032812823865136 ,  0.0212207692167898]],

        [[-0.0265537883114599,  0.0280364159861435,  0.0027999313607641],
         [ 0.0280356838286977,  0.4300057094841492,  0.0711167561483483],
         [ 0.0028008877173101,  0.0711167590345951, -0.0188834771153168]],

        [[ 0.0215198898387836, -0.031100138694895 ,  0.0034520875638044],
         [-0.0042300795324302, -0.0574485161345395, -0.011293806742918 ],
         [ 0.006071742150171 , -0.0238195073950509,  0.0092259066575284]],

        [[-0.0087897950546978,  0.0077329644918023,  0.0014858522405237],
         [ 0.0069600260220737, -0.0063154296129908, -0.017950403495004 ],
         [ 0.0025707062291658, -0.0144844277661926, -0.0115631987543385]]],


       [[[-0.5631981291014387, -0.0421528096943291, -0.4026862782104956],
         [-0.0127431449786775,  0.1461110761713513, -0.0044783628956324],
         [-0.4030846823723788, -0.0167210497092896, -0.0909792552441502]],

        [[ 0.0215192912686873, -0.0042311047909749,  0.0060719560596167],
         [-0.0311007162406632, -0.0574491767011409, -0.0238194060092622],
         [ 0.0034516618868906, -0.0112928265325607,  0.0092247984134763]],

        [[ 0.5448428814074369,  0.0467446440063357,  0.3963146766036152],
         [ 0.0467451058514534, -0.0877858009699084,  0.0269845353028098],
         [ 0.3963148097061442,  0.0269848332785649,  0.0820484189942849]],

        [[-0.0031640435635971, -0.0003607295108454,  0.0002996455498172],
         [-0.0029012446319254, -0.0008760985011347,  0.0013132336018629],
         [ 0.0033182107773144,  0.0010290429619253, -0.0002939621596143]]],


       [[[ 0.0078054889197654, -0.0140356641091799,  0.0045126400246009],
         [-0.0173492152711896, -0.0199454559561829,  0.0292383102136751],
         [ 0.0086169191971797,  0.0324203680515112, -0.0017032086265245]],

        [[-0.0087886606801521,  0.0069605205768042,  0.0025708432059846],
         [ 0.0077325109240495, -0.0063174013167355, -0.014483500062612 ],
         [ 0.0014854317597936, -0.0179506621575953, -0.0115629838429721]],

        [[-0.0031656710521855, -0.0029017588178415,  0.0033178427655267],
         [-0.0003608430507729, -0.0008764673800621,  0.0010296243039276],
         [ 0.0002997056496312,  0.0013136119419999, -0.0002945709423052]],

        [[ 0.0041488428184633,  0.009976902353992 , -0.0104013259940028],
         [ 0.0099775473980102,  0.0271393246558116, -0.0157844344571556],
         [-0.0104020566066565, -0.015783317837248 ,  0.0135607634125789]]]])

        test_hessian = analytical_d2enlc(mf)

        self.assertLess(np.linalg.norm(test_hessian - reference_hessian), 2e-5)

    def test_wb97xv_hessian_high_cost(self):
        mf = make_mf(mol, vv10_only = False, density_fitting = True)
        # reference_hessian = numerical_d2enlc(mf)
        reference_hessian = np.array([[[[ 0.4979170248502474,  0.0488882371119104,  0.2658377292182879],
         [ 0.0488888333068926,  0.1883207192108216, -0.0079990676912778],
         [ 0.2658379285943591, -0.0080001048310407,  0.1861260525712338]],

        [[-0.0518182468757095, -0.0367192982126952, -0.0084762016405726],
         [ 0.0114640970122204, -0.1366653643826155,  0.0052881575722807],
         [-0.0021614507570294, -0.0054693466973177, -0.0458043793829521]],

        [[-0.4501792270654725, -0.0004566340395251, -0.2609481898163679],
         [-0.043965400693402 , -0.0207535793726454, -0.0268411842651028],
         [-0.2622101640328278, -0.0026543058380124, -0.1453741557247978]],

        [[ 0.0040804490793467, -0.0117123048619661,  0.0035866622351555],
         [-0.0163875296323446, -0.0309017755172059,  0.0295520943831007],
         [-0.0014663138077076,  0.0161237573675088,  0.0050524825271903]]],


       [[[-0.0518200884548348,  0.0114620683238087, -0.0021632246499093],
         [-0.0367173716713243, -0.1366660962407451, -0.005471978241578 ],
         [-0.0084767030211763,  0.0052904095459994, -0.0458021925813235]],

        [[ 0.0605613287397999,  0.0148803610287018,  0.0369226818253132],
         [ 0.0148803931932923,  0.1864629046046673,  0.0350035881536703],
         [ 0.0369235508926313,  0.0349998659148198,  0.0153203814008407]],

        [[ 0.0065867310907741, -0.0361258279417132, -0.0023145476275577],
         [ 0.0062552829953599, -0.0364363880731022,  0.0001731457591747],
         [ 0.0037259297804848, -0.0226819064063077,  0.0014463435050738]],

        [[-0.015327971372936 ,  0.0097833985961693, -0.0324449095495116],
         [ 0.0155816954831023, -0.0133604203572946, -0.0297047556676033],
         [-0.0321727776542158, -0.0176083690487661,  0.0290354676804605]]],


       [[[-0.450177062162771 , -0.0439670468992404, -0.2622123128672715],
         [-0.0004572860455854, -0.0207532098732699, -0.0026525845637226],
         [-0.2609493509390104, -0.0268425226543911, -0.1453765479265123]],

        [[ 0.0065859967183085,  0.006258272889248 ,  0.003726127265069 ],
         [-0.0361284310268842, -0.036436495037151 , -0.0226831010925466],
         [-0.0023148937954784,  0.0001730991341375,  0.0014473155842687]],

        [[ 0.4442199147351999,  0.0378917754669805,  0.2597360978199292],
         [ 0.0378915918426426,  0.0580061764183792,  0.0246759551948417],
         [ 0.259734559926228 ,  0.0246775161092394,  0.1443710808441967]],

        [[-0.0006288492747086, -0.0001830014569604, -0.0012499122121756],
         [-0.0013058747669326, -0.0008164715054604,  0.0006597304641476],
         [ 0.0035296848096   ,  0.00199190740241  , -0.0004418484972346]]],


       [[[ 0.0040811043533484, -0.0163886407653635, -0.0014693834971546],
         [-0.0117128923626808, -0.0309014152932718,  0.016124835307052 ],
         [ 0.0035876584869587,  0.0295554798528386,  0.0050533075951487]],

        [[-0.0153259523469201,  0.0155893764368642, -0.0321715411738532],
         [ 0.0097833346091175, -0.0133608673431596, -0.0176091951796797],
         [-0.0324461302844831, -0.0297066942387403,  0.0290357279340014]],

        [[-0.0006252535320606, -0.0013102507717133,  0.0035306797017132],
         [-0.0001820687598464, -0.0008164011988665,  0.0019922401819361],
         [-0.0012508998239458,  0.000658272115539 , -0.0004428593658456]],

        [[ 0.0118701015253686,  0.0021095150880557,  0.0301102449726809],
         [ 0.0021116265142007,  0.0450786838410155, -0.0005078803117509],
         [ 0.0301093716240652, -0.0005070577182298, -0.0336461761628049]]]])

        test_hessian = analytical_d2enlc(mf)

        self.assertLess(np.linalg.norm(test_hessian - reference_hessian), 4e-4)

    # sto-6g is used deliberately: with this basis, prune_by_density_() is
    # triggered and removes some of the grid points.
    def test_wb97xv_sto6g_hessian_high_cost(self):
        mol_copy = mol.copy()
        mol_copy.basis = "sto-6g"
        mol_copy.build()
        mf = make_mf(mol_copy, vv10_only = False, density_fitting = True)

        # reference_hessian = numerical_d2enlc(mf)
        reference_hessian = np.array([[[[ 0.6336308259090595,  0.0573456704611175,  0.3625810477652647],
         [ 0.0573439018618505,  0.3182666549745861,  0.0059004173367794],
         [ 0.3625793744401751,  0.0058954672264022,  0.2139051350200094]],

        [[-0.0636687016642989, -0.0395097887926354, -0.0111143854187867],
         [ 0.0225302932161872, -0.2903368800994954, -0.0036306306906431],
         [-0.004465864577384 , -0.0261923132391928, -0.055590819173168 ]],

        [[-0.5743356486271889, -0.0028593979861657, -0.3563308539343835],
         [-0.0600891353197408,  0.0046906189776208, -0.0377451875586132],
         [-0.3587291728373021, -0.0038227902587895, -0.15923091729797  ]],

        [[ 0.0043735243854259, -0.0149764836818445,  0.0048641915896264],
         [-0.0197850597583038, -0.0326203939172154,  0.0354754009094238],
         [ 0.0006156629773768,  0.0241196362663898,  0.0009166014511841]]],


       [[[-0.0636664011941512,  0.022530169337287 , -0.0044681994587625],
         [-0.0395082509730971, -0.2903338271051936, -0.0261978020329456],
         [-0.0111154612613129, -0.0036278945136914, -0.0555865565051716]],

        [[ 0.075830416243601 ,  0.0136975375717441,  0.0441797131556232],
         [ 0.0136998482052134,  0.3581631537308283,  0.0541899923964806],
         [ 0.0441821724886104,  0.0541881110178721,  0.0220380469497794]],

        [[ 0.0088627112848627, -0.0472581671259187, -0.0025102066963933],
         [ 0.009537496212797 , -0.0408628785854015,  0.0025876847447037],
         [ 0.0046611743063085, -0.0290367572925998,  0.0018261042359358]],

        [[-0.0210267263351938,  0.0110304602121969, -0.0372013069985799],
         [ 0.0162709065620881, -0.026966448098964 , -0.0305798751083497],
         [-0.0377278855333563, -0.0215234592090552,  0.0317224053159038]]],


       [[[-0.5743368089190515, -0.0600885191639478, -0.358729193757068 ],
         [-0.0028612693276919,  0.0046920502780878, -0.0038226079453474],
         [-0.3563298158298922, -0.0377457546782978, -0.159233081433785 ]],

        [[ 0.0088643165882113,  0.0095385138880744,  0.0046626080222323],
         [-0.047259653162629 , -0.040861570954398 , -0.0290352027254581],
         [-0.0025099426689545,  0.0025892681296824,  0.0018263193690693]],

        [[ 0.5662338972167724,  0.0501690767570895,  0.3553844514642135],
         [ 0.0501703419448774,  0.0357379195037311,  0.0338090551490478],
         [ 0.3553846544752659,  0.0338098827545319,  0.1575966282131303]],

        [[-0.0007614048827542,  0.0003809285181455, -0.0013178657324864],
         [-0.0000494194573597,  0.0004316011768257, -0.0009512444764659],
         [ 0.0034551040241637,  0.0013466037849241, -0.0001898661470268]]],


       [[[ 0.0043737895226714, -0.019787015691719 ,  0.0006156213182007],
         [-0.0149770889950052, -0.032619209057394 ,  0.0241192794023237],
         [ 0.0048664261380615,  0.0354762772900585,  0.0009145229857288]],

        [[-0.0210255606446844,  0.0162741471081418, -0.0377278807552894],
         [ 0.0110331808902547, -0.0269675141693071, -0.0215222514396984],
         [-0.0372058410066725, -0.0305813428953527,  0.0317231549640251]],

        [[-0.0007628110116897, -0.0000521537331655,  0.0034563912642005],
         [ 0.0003806648771754,  0.0004308684665166,  0.0013480453870951],
         [-0.001317782241772 , -0.0009512145008328, -0.0001908001644457]],

        [[ 0.0174145821283944,  0.0035650223215722,  0.033655868171667 ],
         [ 0.0035632432306421,  0.0591558547581583, -0.0039450733491653],
         [ 0.0336571971085164, -0.0039437198910419, -0.0324468777884168]]]])

        test_hessian = analytical_d2enlc(mf)

        self.assertLess(np.linalg.norm(test_hessian - reference_hessian), 2e-4)

    def test_vv10_only_hessian_direct(self):
        mf = make_mf(mol, nlcgrid=(6, 50), vv10_only = True)

        # reference_hessian = numerical_d2enlc(mf)
        reference_hessian = np.array([[[[ 0.5422389670473038,  0.0608849029992697,  0.4060492194686849],
         [ 0.0608669187469602,  0.2411577606364901,  0.0171918513361957],
         [ 0.4060462220221162,  0.0171775277699737,  0.0720223225585959]],

        [[ 0.0138905973397045, -0.0307118065654421, -0.0114078760189429],
         [-0.0046690720816207, -0.3668739312896463, -0.0329633165456755],
         [-0.0077408774765558, -0.0419989615298988,  0.0213116168268357]],

        [[-0.5639354954045661, -0.0127790746430878, -0.4032449109262481],
         [-0.0421630194709088,  0.1456179152686254, -0.0166743922528356],
         [-0.4028010528314141, -0.0044539502911944, -0.0915877247736896]],

        [[ 0.0078059310136824, -0.0173940217843283,  0.0086035674823903],
         [-0.0140348271938096, -0.0199017446258776,  0.0324458574623709],
         [ 0.0044957082873109,  0.0292753840491766, -0.0017462146124636]]],


       [[[ 0.0138980364737806, -0.0046642049422729, -0.0077435545464688],
         [-0.0306979373912997, -0.3668412227018081, -0.0420127914597668],
         [-0.0114147704678436, -0.0329573833637298,  0.0213225931645078]],

        [[-0.0266719779840469,  0.0280174462492511,  0.0027240503223114],
         [ 0.0280123127845838,  0.4306066618902094,  0.0713271116213887],
         [ 0.0027196766596155,  0.0713041965729744, -0.0188449660004641]],

        [[ 0.0215962253455881, -0.0310972935545495,  0.0034762614354866],
         [-0.0042761914516021, -0.0574536715330365, -0.0113406881657729],
         [ 0.0060599168305586, -0.0238334537607243,  0.0092279335832668]],

        [[-0.0088222838390273,  0.007744052248404 ,  0.001543242789892 ],
         [ 0.0069618160585816, -0.006311767673628 , -0.01797363199596  ],
         [ 0.002635176977378 , -0.0145133594611491, -0.0117055607486982]]],


       [[[-0.5639328402589072, -0.0421951998657022, -0.4028052479237987],
         [-0.0127782705844348,  0.1456589600197233, -0.0044589841909359],
         [-0.4032391759897225, -0.0166931419425254, -0.0915837547085241]],

        [[ 0.021608048656422 , -0.0042471804793109,  0.0060468203046948],
         [-0.0311080784461647, -0.0574848097185665, -0.0238100972493882],
         [ 0.0034819763398558, -0.0113282096636524,  0.0092186775002556]],

        [[ 0.5455197763624131,  0.0468034628006153,  0.3964345582322393],
         [ 0.0467970845283716, -0.0872851673504593,  0.0269840940250177],
         [ 0.3964321605172128,  0.0269819456438225,  0.0826381532419429]],

        [[-0.0031949847581481, -0.0003610824491074,  0.0003238693827012],
         [-0.0029107354926894, -0.0008889829466174,  0.0012849874124199],
         [ 0.0033250391236472,  0.0010394059546115, -0.0002730760335634]]],


       [[[ 0.0077956769360954, -0.0140570310049881,  0.004487601537928 ],
         [-0.0173948867028262, -0.0199117841974683,  0.0292754697035491],
         [ 0.0086099775260529,  0.0323876106653742, -0.0017621906810983]],

        [[-0.00881402280236  ,  0.0069704554839234,  0.0026284535072918],
         [ 0.0077466940723754, -0.0063043244020733, -0.0145185847585383],
         [ 0.0015608904519217, -0.0178983608414285, -0.0117255174202358]],

        [[-0.0031746446897962, -0.0029217742117393,  0.0033429861124867],
         [-0.0003581371461614, -0.0008892354524392,  0.0010405721901918],
         [ 0.0003158356047805,  0.0013191378129274, -0.000287083537609 ]],

        [[ 0.0041929905508775,  0.0100083497392434, -0.0104590411580952],
         [ 0.0100063297674979,  0.0271053440418778, -0.0157974571338149],
         [-0.0104867035888023, -0.0158083876474757,  0.013774791643717 ]]]])

        test_hessian = analytical_d2enlc(mf)

        self.assertLess(abs(test_hessian - reference_hessian).max(), 6e-3)

    def test_vv10_only_hessian_density_fitting(self):
        mf = make_mf(mol, nlcgrid=(6, 50), vv10_only = True, density_fitting = True)

        # reference_hessian = numerical_d2enlc(mf)
        reference_hessian = np.array([[[[ 0.542168430863188 ,  0.0609222841019275,  0.405984842392626 ],
         [ 0.0608663024931344,  0.2411645990134659,  0.0172038828970544],
         [ 0.4059818139032778,  0.0171681571874416,  0.071982739693055 ]],

        [[ 0.0138624531238994, -0.030753281953011 , -0.011391700686314 ],
         [-0.0046718942045276, -0.3668973349446603, -0.0329731695137836],
         [-0.0077417098879456, -0.0420075016576149,  0.0212935375722978]],

        [[-0.5638612124752029, -0.0127755047237832, -0.4031785618546779],
         [-0.0421554120917378,  0.1456378880321796, -0.0166844553116663],
         [-0.4027474975891732, -0.0044523772413929, -0.0915300473287584]],

        [[ 0.0078303284871684, -0.0173934974137535,  0.008585420150975 ],
         [-0.0140389961968967, -0.0199051521382332,  0.032453741928784 ],
         [ 0.0045073935751039,  0.0292917217102895, -0.0017462299338744]]],


       [[[ 0.0138627725791984, -0.0044830416854325, -0.0077689479912046],
         [-0.0307033815374336, -0.3668782336108123, -0.0420316349443173],
         [-0.011416036342915 , -0.0329680821886669,  0.0213094739394681]],

        [[-0.0267047939085836,  0.0278498880152966,  0.0028205881821286],
         [ 0.0280123104940827,  0.4306583716680024,  0.0713229735657128],
         [ 0.002716706747629 ,  0.0713137343293369, -0.0188255899165046]],

        [[ 0.0216086238817859, -0.031129264683738 ,  0.0035280544075089],
         [-0.0042696833535416, -0.0574589407569825, -0.0113311991081222],
         [ 0.0060647834763117, -0.02382966993697  ,  0.0092297279771225]],

        [[-0.0087666025508359,  0.0077624183455471,  0.0014203053994022],
         [ 0.0069607543997652, -0.0063211972930188, -0.0179601395100537],
         [ 0.0026345461165422, -0.01451598221619  , -0.0117136120006411]]],


       [[[-0.5638547756169343, -0.0421853018996998, -0.4027489367187109],
         [-0.0127656104511598,  0.1456436572109254, -0.0044564700579786],
         [-0.4031782977184095, -0.0166777303007848, -0.0915248894605147]],

        [[ 0.0216136396733041, -0.0042396699795333,  0.0060505042513981],
         [-0.0310972333722112, -0.0574617112567566, -0.0238279645358164],
         [ 0.0034860909206118, -0.0113321365784458,  0.0092292669409133]],

        [[ 0.5454448936412781,  0.0467928278031216,  0.3963708627952034],
         [ 0.0467864996647194, -0.0872893943619224,  0.026974441280414 ],
         [ 0.3963719738691029,  0.0269737337457165,  0.0825801313625307]],

        [[-0.0032037577029076, -0.0003678559157838,  0.0003275696731087],
         [-0.0029236558411472, -0.000892551613535 ,  0.0013099933099392],
         [ 0.0033202329213777,  0.0010361331211073, -0.0002845088384884]]],


       [[[ 0.0078053031556946, -0.0140306109983612,  0.0045168348619118],
         [-0.0174043414807867, -0.0197821597556214,  0.0292915769945501],
         [ 0.0086166148194344,  0.0324541742315887, -0.0017644739562228]],

        [[-0.008832216989163 ,  0.0069423500521637,  0.0026415983758188],
         [ 0.0077126990719739, -0.0064333042244513, -0.0144527745323808],
         [ 0.00154617825978  , -0.0179540412803547, -0.0117113252550149]],

        [[-0.0031904972070951, -0.0029275652797534,  0.0033219078221514],
         [-0.0003700651405314, -0.000914282871356 ,  0.0010491373123589],
         [ 0.0003082802009224,  0.0013125269376357, -0.000284007220519 ]],

        [[ 0.0042174110524429,  0.010015826224008 , -0.0104803410539422],
         [ 0.0100617075471376,  0.0271297468399656, -0.0158879397739731],
         [-0.0104710732851121, -0.0158126598952257,  0.0137598064317013]]]])

        test_hessian = analytical_d2enlc(mf)

        self.assertLess(abs(test_hessian - reference_hessian).max(), 6e-3)

    def test_wb97xv_hessian(self):
        mf = make_mf(mol, nlcgrid=(6, 50), vv10_only = False, density_fitting = True)
        # reference_hessian = numerical_d2enlc(mf)
        reference_hessian = np.array([[[[ 0.4988631231107599,  0.0489672839414368,  0.2660577198785319],
         [ 0.0489734877394676,  0.1897255830431988, -0.0078101261571928],
         [ 0.2660599405173159, -0.0078137822629998,  0.1868350936873875]],

        [[-0.051771569530977 , -0.0366657150712157, -0.0084571181140358],
         [ 0.0115135376406936, -0.1374562090887821,  0.0051198866219959],
         [-0.0021482104426498, -0.0056229068539082, -0.0457791297491239]],

        [[-0.4511626969443405, -0.0005718768824048, -0.2611939416498066],
         [-0.0440763711142544, -0.0213677280480606, -0.0269206001783928],
         [-0.2624601358143241, -0.0027332207694353, -0.1460975133965592]],

        [[ 0.0040769871449409, -0.0117313768899785,  0.0036000488370935],
         [-0.0164145086085088, -0.0309006412363977,  0.0296110518733794],
         [-0.0014482905957292,  0.0161724903490201,  0.0050415106733204]]],


       [[[-0.0517810751464731,  0.0115134106239978, -0.0021413970734963],
         [-0.0366680715491174, -0.137456690429083 , -0.0056334229160959],
         [-0.0084643006763585,  0.0051188671124081, -0.045771603517375 ]],

        [[ 0.0605579536767445,  0.0148391214169763,  0.0368378070428887],
         [ 0.0148416118147185,  0.1872218688099281,  0.0352017457353471],
         [ 0.0368422860602124,  0.03519618041814  ,  0.015470289529862 ]],

        [[ 0.0066125287137009, -0.0361419880820502, -0.0023317625458263],
         [ 0.0062373106164415, -0.0364063830698425,  0.0001598393409852],
         [ 0.0037036923532519, -0.0227040970159353,  0.001451774162442 ]],

        [[-0.0153913803045969,  0.0097899022819259, -0.0323681063152392],
         [ 0.0155889141273957, -0.0133553522042584, -0.0297307874967956],
         [-0.0320854087921818, -0.0176186038926551,  0.0288496001697447]]],


       [[[-0.451159345403962 , -0.0440801979015859, -0.2624618870713524],
         [-0.0005596911450967, -0.0213675101683464, -0.0027215831065464],
         [-0.2612002796613666, -0.0269246484547736, -0.1460934137018088]],

        [[ 0.0066130286903998,  0.0062411816443841,  0.003706422429528 ],
         [-0.0361560863584784, -0.0364033278010822, -0.0227120261816482],
         [-0.0023337153170974,  0.0001573468182414,  0.001453759111425 ]],

        [[ 0.4451884345406665,  0.0380248152018758,  0.2600027132105831],
         [ 0.0380226037735021,  0.058595325462707 ,  0.0247693793083981],
         [ 0.2600041998666036,  0.0247752873194784,  0.1450814677061985]],

        [[-0.0006382995647036, -0.0001875031252485, -0.001245653350801 ],
         [-0.001302387582558 , -0.0008200772848854,  0.0006668422152245],
         [ 0.0035264024371151,  0.0019895196946285, -0.0004391935750081]]],


       [[[ 0.0040892562166439, -0.0164117955323562, -0.0014552053189121],
         [-0.0117341212848565, -0.0308982989043471,  0.0161749134836575],
         [ 0.0036021915061912,  0.0296147042568862,  0.0050379575677328]],

        [[-0.015389526193503 ,  0.0156001358382341, -0.0320831071690675],
         [ 0.009787943389114 , -0.0133567006015411, -0.0176186446311233],
         [-0.0323661340360867, -0.0297316096826705,  0.0288474367043934]],

        [[-0.0006373766803236, -0.0013134771926215,  0.0035267654011628],
         [-0.000186969273841 , -0.0008195549733858,  0.0019887957584119],
         [-0.0012462138236913,  0.0006628131786202, -0.0004407752141056]],

        [[ 0.0119453748912424,  0.0021267717670859,  0.0300156238953186],
         [ 0.0021290136149599,  0.0450750641534936, -0.0005457257277741],
         [ 0.0300147763035441, -0.0005453235773223, -0.0334444976666792]]]])

        test_hessian = analytical_d2enlc(mf)

        self.assertLess(abs(test_hessian - reference_hessian).max(), 4e-3)

    def test_vv10_energy_second_derivative_high_cost(self):
        mf = make_mf(mol, vv10_only = True, density_fitting = True)
        hess_obj = mf.Hessian()

        reference_de2 = _get_enlc_deriv2_numerical(hess_obj, mf.mo_coeff, mf.mo_occ, max_memory = 2000)
        test_de2 = _get_enlc_deriv2(hess_obj, mf.mo_coeff, mf.mo_occ, max_memory = 2000)

        self.assertLess(np.linalg.norm(test_de2 - reference_de2), 1e-5)

    def test_vv10_fock_first_derivative(self):
        mf = make_mf(mol, vv10_only = True, density_fitting = True, nlcgrid = (10,14))
        hess_obj = mf.Hessian()

        reference_dF = _get_vnlc_deriv1_numerical(hess_obj, mf.mo_coeff, mf.mo_occ, max_memory = 2000)
        test_dF = _get_vnlc_deriv1(hess_obj, mf.mo_coeff, mf.mo_occ, max_memory = 2000)

        self.assertLess(np.linalg.norm(test_dF - reference_dF), 5e-8)

if __name__ == "__main__":
    # Direct execution: print a banner, then hand control to unittest's
    # command-line runner, which runs every test_* method in this module.
    banner = "Full Tests for RKS Hessian with VV10"
    print(banner)
    unittest.main()
